{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d27544d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from dataclasses import dataclass\n",
    "from pathlib import Path\n",
    "from typing import Dict, List, Optional, Tuple\n",
    "\n",
    "from dotenv import load_dotenv\n",
    "from openai import OpenAI\n",
    "import gradio as gr\n",
    "\n",
    "from pathlib import Path\n",
    "from typing import List, Tuple\n",
    "from transformers import AutoTokenizer\n",
    "\n",
    "\n",
    "# ---- environment ----\n",
    "# override=True: values in .env take precedence over variables already set in the shell.\n",
    "load_dotenv(override=True)\n",
    "\n",
    "# ---- OpenAI-compatible endpoints for the non-OpenAI providers ----\n",
    "GEMINI_BASE = \"https://generativelanguage.googleapis.com/v1beta/openai/\"\n",
    "GROQ_BASE = \"https://api.groq.com/openai/v1\"\n",
    "\n",
    "# ---- API keys (None when the variable is not set) ----\n",
    "OPENAI_API_KEY = os.getenv(\"OPENAI_API_KEY\")\n",
    "GOOGLE_API_KEY = os.getenv(\"GOOGLE_API_KEY\")  # Gemini\n",
    "GROQ_API_KEY = os.getenv(\"GROQ_API_KEY\")      # Groq\n",
    "\n",
    "# ---- one client per provider; a missing or empty key leaves it as None ----\n",
    "openai_client = None\n",
    "if OPENAI_API_KEY:\n",
    "    openai_client = OpenAI()  # no explicit key: the SDK reads it from the environment\n",
    "gemini_client = None\n",
    "if GOOGLE_API_KEY:\n",
    "    gemini_client = OpenAI(api_key=GOOGLE_API_KEY, base_url=GEMINI_BASE)\n",
    "groq_client = None\n",
    "if GROQ_API_KEY:\n",
    "    groq_client = OpenAI(api_key=GROQ_API_KEY, base_url=GROQ_BASE)\n",
    "\n",
    "# ---- model registry: UI label -> {\"client\": ..., \"model\": ...} ----\n",
    "MODEL_REGISTRY: Dict[str, Dict[str, object]] = {}\n",
    "\n",
    "def _register(label: str, client: Optional[OpenAI], model_id: str):\n",
    "    \"\"\"Record `label` -> (client, model_id) in MODEL_REGISTRY; skipped when client is None.\"\"\"\n",
    "    if client is None:\n",
    "        return\n",
    "    MODEL_REGISTRY[label] = {\"client\": client, \"model\": model_id}\n",
    "\n",
    "# OpenAI\n",
    "_register(label=\"OpenAI • GPT-5\", client=openai_client, model_id=\"gpt-5\")\n",
    "_register(label=\"OpenAI • GPT-5 Nano\", client=openai_client, model_id=\"gpt-5-nano\")\n",
    "_register(label=\"OpenAI • GPT-4o-mini\", client=openai_client, model_id=\"gpt-4o-mini\")\n",
    "\n",
    "# Gemini (Google)\n",
    "_register(label=\"Gemini • 2.5 Pro\", client=gemini_client, model_id=\"gemini-2.5-pro\")\n",
    "_register(label=\"Gemini • 2.5 Flash\", client=gemini_client, model_id=\"gemini-2.5-flash\")\n",
    "\n",
    "# Groq\n",
    "_register(label=\"Groq • Llama 3.1 8B\", client=groq_client, model_id=\"llama-3.1-8b-instant\")\n",
    "_register(label=\"Groq • Llama 3.3 70B\", client=groq_client, model_id=\"llama-3.3-70b-versatile\")\n",
    "_register(label=\"Groq • GPT-OSS 20B\", client=groq_client, model_id=\"openai/gpt-oss-20b\")\n",
    "_register(label=\"Groq • GPT-OSS 120B\", client=groq_client, model_id=\"openai/gpt-oss-120b\")\n",
    "\n",
    "# The dropdown lists models in registration order (dicts keep insertion order).\n",
    "AVAILABLE_MODELS = list(MODEL_REGISTRY)\n",
    "if AVAILABLE_MODELS:\n",
    "    DEFAULT_MODEL = AVAILABLE_MODELS[0]\n",
    "else:\n",
    "    # Placeholder label only; nothing is registered under it.\n",
    "    DEFAULT_MODEL = \"OpenAI • GPT-4o-mini\"\n",
    "\n",
    "print(\n",
    "    \"Providers configured →\",\n",
    "    f\"OpenAI:{bool(OPENAI_API_KEY)}  Gemini:{bool(GOOGLE_API_KEY)}  Groq:{bool(GROQ_API_KEY)}\",\n",
    ")\n",
    "print(\"Models available     →\", \", \".join(AVAILABLE_MODELS) or \"None (add API keys in .env)\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "efe4e4db",
   "metadata": {},
   "outputs": [],
   "source": [
    "@dataclass(frozen=True)\n",
    "class LLMRoute:\n",
    "    \"\"\"Immutable pairing of a provider client with the model id sent to it.\"\"\"\n",
    "    client: OpenAI  # OpenAI-compatible client (OpenAI, Gemini or Groq endpoint)\n",
    "    model: str      # provider model id, e.g. \"gpt-4o-mini\"\n",
    "\n",
    "class MultiLLM:\n",
    "    \"\"\"OpenAI-compatible chat across providers (OpenAI, Gemini, Groq).\n",
    "\n",
    "    Callers address models by registry label; each label resolves to an\n",
    "    `LLMRoute` (client + model id).\n",
    "    \"\"\"\n",
    "    def __init__(self, registry: Dict[str, Dict[str, object]]):\n",
    "        routes: Dict[str, LLMRoute] = {}\n",
    "        for label, entry in registry.items():\n",
    "            routes[label] = LLMRoute(client=entry[\"client\"], model=str(entry[\"model\"]))\n",
    "        if len(routes) == 0:\n",
    "            raise RuntimeError(\"No LLM providers configured. Add API keys in .env.\")\n",
    "        self._routes = routes\n",
    "\n",
    "    def complete(self, *, model_label: str, system: str, user: str) -> str:\n",
    "        \"\"\"Run one system+user chat turn on the model behind `model_label`.\n",
    "\n",
    "        Returns the first choice's text, stripped (\"\" if the provider sent none).\n",
    "        Raises ValueError for a label that is not registered.\n",
    "        \"\"\"\n",
    "        route = self._routes.get(model_label)\n",
    "        if route is None:\n",
    "            raise ValueError(f\"Unknown model: {model_label}\")\n",
    "        messages = [\n",
    "            {\"role\": \"system\", \"content\": system},\n",
    "            {\"role\": \"user\", \"content\": user},\n",
    "        ]\n",
    "        response = route.client.chat.completions.create(model=route.model, messages=messages)\n",
    "        text = response.choices[0].message.content\n",
    "        return (text or \"\").strip()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30636b66",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# MiniLM embedding model & tokenizer (BERT WordPiece)\n",
    "EMBED_MODEL_NAME = \"sentence-transformers/all-MiniLM-L6-v2\"\n",
    "\n",
    "# Sentence-Transformers truncates all-MiniLM-L6-v2 inputs at 256 word pieces, and that\n",
    "# limit includes the [CLS]/[SEP] tokens added at embedding time. Chunks are sized to\n",
    "# 256 minus those special tokens so the tail of a full chunk is not silently cut off.\n",
    "EMBED_MAX_SEQ_LEN = 256   # all-MiniLM-L6-v2 max_seq_length in Sentence-Transformers\n",
    "OVERLAP_RATIO = 0.50      # 50% sliding window overlap\n",
    "\n",
    "TOKENIZER = AutoTokenizer.from_pretrained(EMBED_MODEL_NAME)\n",
    "MAX_TOKENS = EMBED_MAX_SEQ_LEN - TOKENIZER.num_special_tokens_to_add(pair=False)\n",
    "\n",
    "def chunk_text(\n",
    "    text: str,\n",
    "    tokenizer: AutoTokenizer = TOKENIZER,\n",
    "    max_tokens: int = MAX_TOKENS,\n",
    "    overlap_ratio: float = OVERLAP_RATIO,\n",
    ") -> List[str]:\n",
    "    \"\"\"\n",
    "    Token-aware sliding window chunking for MiniLM.\n",
    "\n",
    "    - Windows of at most `max_tokens` tokens (no special tokens added).\n",
    "    - Step = max(1, int(max_tokens * (1 - overlap_ratio)))  -> 50% overlap by default.\n",
    "    - The last window always reaches the end of the text.\n",
    "\n",
    "    Chunk text is sliced out of the ORIGINAL `text` via the tokenizer's character\n",
    "    offsets, so casing, punctuation spacing and out-of-vocabulary characters survive.\n",
    "    Decoding token ids instead (the fallback for tokenizers without offset support)\n",
    "    can lowercase text with an uncased vocab and turn unknown characters into [UNK],\n",
    "    which breaks exact quoting of provisions.\n",
    "\n",
    "    Returns [] when `text` produces no tokens.\n",
    "    Raises ValueError if `max_tokens` < 1.\n",
    "    \"\"\"\n",
    "    if max_tokens < 1:\n",
    "        raise ValueError(f\"max_tokens must be >= 1, got {max_tokens}\")\n",
    "\n",
    "    offsets = None\n",
    "    if getattr(tokenizer, \"is_fast\", False):\n",
    "        # Fast tokenizers map every token back to a (start, end) span of `text`.\n",
    "        enc = tokenizer(text, add_special_tokens=False, return_offsets_mapping=True)\n",
    "        ids = enc[\"input_ids\"]\n",
    "        offsets = enc[\"offset_mapping\"]\n",
    "    else:\n",
    "        ids = tokenizer.encode(text, add_special_tokens=False)\n",
    "    if not ids:\n",
    "        return []\n",
    "\n",
    "    step = max(1, int(max_tokens * (1.0 - overlap_ratio)))\n",
    "    out: List[str] = []\n",
    "    for start in range(0, len(ids), step):\n",
    "        end = min(start + max_tokens, len(ids))\n",
    "        if offsets is not None:\n",
    "            chunk = text[offsets[start][0] : offsets[end - 1][1]].strip()\n",
    "        else:\n",
    "            toks = tokenizer.convert_ids_to_tokens(ids[start:end])\n",
    "            chunk = tokenizer.convert_tokens_to_string(toks).strip()\n",
    "        if chunk:\n",
    "            out.append(chunk)\n",
    "        if end >= len(ids):\n",
    "            break  # this window already covers the tail of the text\n",
    "    return out\n",
    "\n",
    "def load_bare_acts(root: str = \"knowledge_base/bare_acts\") -> List[Tuple[str, str]]:\n",
    "    \"\"\"Load every `*.txt` file directly under `root`, sorted by path.\n",
    "\n",
    "    Returns (source_id, text) pairs; `source_id` is the filename stem (\"bns.txt\" -> \"bns\").\n",
    "\n",
    "    Raises:\n",
    "        FileNotFoundError: `root` does not exist.\n",
    "        RuntimeError: `root` holds no .txt files.\n",
    "    \"\"\"\n",
    "    base = Path(root)\n",
    "    if not base.exists():\n",
    "        raise FileNotFoundError(f\"Folder not found: {base.resolve()}\")\n",
    "    pairs: List[Tuple[str, str]] = [\n",
    "        # utf-8-sig reads plain UTF-8 unchanged but drops a leading BOM, which\n",
    "        # would otherwise end up at the start of the file's first chunk.\n",
    "        (p.stem, p.read_text(encoding=\"utf-8-sig\"))\n",
    "        for p in sorted(base.glob(\"*.txt\"))\n",
    "        if p.is_file()  # skip directories whose name happens to end in .txt\n",
    "    ]\n",
    "    if not pairs:\n",
    "        # Report the folder actually searched, which may not be the default `root`.\n",
    "        raise RuntimeError(f\"No .txt files found under {base.resolve()}\")\n",
    "    return pairs\n",
    "\n",
    "# Load the corpus once; `acts_raw` is what the index cell below embeds.\n",
    "acts_raw = load_bare_acts()\n",
    "print(\"Bare Acts loaded:\", [source_id for source_id, _text in acts_raw])\n",
    "print(f\"Chunking → max_tokens={MAX_TOKENS}, overlap={int(OVERLAP_RATIO*100)}%\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af537e05",
   "metadata": {},
   "outputs": [],
   "source": [
    "import chromadb\n",
    "from chromadb import PersistentClient\n",
    "from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction\n",
    "from transformers import AutoTokenizer\n",
    "from typing import Dict, List, Tuple\n",
    "\n",
    "class BareActsIndex:\n",
    "    \"\"\"Owns the vector DB lifecycle & retrieval (token-aware chunking).\n",
    "\n",
    "    Chunks are sized with the tokenizer of `embed_model` (the same model Chroma\n",
    "    uses to embed them) and stored in a persistent Chroma collection at `db_path`.\n",
    "    \"\"\"\n",
    "    def __init__(\n",
    "        self,\n",
    "        db_path: str = \"vector_db\",\n",
    "        collection: str = \"bare_acts\",\n",
    "        embed_model: str = EMBED_MODEL_NAME,\n",
    "        max_tokens: int = MAX_TOKENS,\n",
    "        overlap_ratio: float = OVERLAP_RATIO,\n",
    "    ):\n",
    "        self.db_path = db_path\n",
    "        self.collection_name = collection\n",
    "        self.embed_model = embed_model\n",
    "        self.max_tokens = max_tokens\n",
    "        self.overlap_ratio = overlap_ratio\n",
    "\n",
    "        self.embed_fn = SentenceTransformerEmbeddingFunction(model_name=self.embed_model)\n",
    "        self.tokenizer = AutoTokenizer.from_pretrained(self.embed_model)\n",
    "\n",
    "        self.client: PersistentClient = PersistentClient(path=db_path)\n",
    "        self.col = self._get_collection()\n",
    "\n",
    "    def _get_collection(self):\n",
    "        \"\"\"Open (or create) the collection bound to this index's embedding function.\"\"\"\n",
    "        return self.client.get_or_create_collection(\n",
    "            name=self.collection_name,\n",
    "            embedding_function=self.embed_fn,\n",
    "        )\n",
    "\n",
    "    def rebuild(self, docs: List[Tuple[str, str]], batch_size: int = 512):\n",
    "        \"\"\"Idempotent rebuild: drop the collection, then re-chunk and re-add every doc.\n",
    "\n",
    "        Args:\n",
    "            docs: (source_id, text) pairs. Chunk ids are \"<source_id>-<chunk index>\".\n",
    "            batch_size: chunks per `add` call. Chroma caps how many records a single\n",
    "                `add` may carry, so a large corpus is written in slices.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if `batch_size` < 1.\n",
    "        \"\"\"\n",
    "        if batch_size < 1:\n",
    "            raise ValueError(f\"batch_size must be >= 1, got {batch_size}\")\n",
    "\n",
    "        try:\n",
    "            self.client.delete_collection(self.collection_name)\n",
    "        except Exception:\n",
    "            # Best effort: presumably this only fails when the collection does not\n",
    "            # exist yet (first run); it is (re)created just below either way.\n",
    "            pass\n",
    "        self.col = self._get_collection()\n",
    "\n",
    "        ids, texts, metas = [], [], []\n",
    "        for src, text in docs:\n",
    "            chunks = chunk_text(\n",
    "                text,\n",
    "                tokenizer=self.tokenizer,\n",
    "                max_tokens=self.max_tokens,\n",
    "                overlap_ratio=self.overlap_ratio,\n",
    "            )\n",
    "            for idx, ch in enumerate(chunks):\n",
    "                ids.append(f\"{src}-{idx}\")\n",
    "                texts.append(ch)\n",
    "                metas.append({\"source\": src, \"chunk_id\": idx})\n",
    "\n",
    "        # Add in bounded slices; an empty corpus simply skips the loop.\n",
    "        for lo in range(0, len(ids), batch_size):\n",
    "            hi = lo + batch_size\n",
    "            self.col.add(ids=ids[lo:hi], documents=texts[lo:hi], metadatas=metas[lo:hi])\n",
    "\n",
    "        print(\n",
    "            f\"Indexed {len(texts)} chunks from {len(docs)} files → {self.collection_name} \"\n",
    "            f\"(tokens/chunk={self.max_tokens}, overlap={int(self.overlap_ratio*100)}%)\"\n",
    "        )\n",
    "\n",
    "    def query(self, q: str, k: int = 6) -> List[Dict]:\n",
    "        \"\"\"Return up to `k` nearest chunks as {\"text\": ..., \"meta\": ...} dicts.\n",
    "\n",
    "        `k` is clamped to the number of stored chunks, so an empty collection (or\n",
    "        k < 1) returns [] instead of asking Chroma for more results than exist.\n",
    "        \"\"\"\n",
    "        n_results = min(k, self.col.count())\n",
    "        if n_results < 1:\n",
    "            return []\n",
    "        res = self.col.query(query_texts=[q], n_results=n_results)\n",
    "        docs = (res.get(\"documents\") or [[]])[0]\n",
    "        metas = (res.get(\"metadatas\") or [[]])[0]\n",
    "        return [{\"text\": d, \"meta\": m} for d, m in zip(docs, metas)]\n",
    "\n",
    "# Build the index for this session; rebuild() drops and re-embeds the whole collection.\n",
    "index = BareActsIndex()\n",
    "index.rebuild(docs=acts_raw)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7eec89e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "class PromptBuilder:\n",
    "    \"\"\"Holds the system prompt and renders the user turn, so every request is built identically.\"\"\"\n",
    "\n",
    "    # Restricts answers to the retrieved context and asks for [file #chunk] citations.\n",
    "    SYSTEM = (\n",
    "        \"You are a precise legal assistant for Indian Bare Acts. \"\n",
    "        \"Answer ONLY from the provided context. If the answer is not in context, say you don't know. \"\n",
    "        \"Cite sources inline in square brackets as [file #chunk] (e.g., [bns #12]). \"\n",
    "        \"Prefer exact quotes for critical provisions/sections.\"\n",
    "    )\n",
    "\n",
    "    @staticmethod\n",
    "    def build_user(query: str, contexts: List[Dict]) -> str:\n",
    "        \"\"\"Render the question, the labelled context chunks and the answer instructions.\n",
    "\n",
    "        Each chunk is headed by \"[source #chunk_id]\" and chunks are separated by \"---\".\n",
    "        \"\"\"\n",
    "        labelled = []\n",
    "        for item in contexts:\n",
    "            meta = item[\"meta\"]\n",
    "            labelled.append(f\"[{meta['source']} #{meta['chunk_id']}]\\n{item['text']}\")\n",
    "        context_block = \"\\n\\n---\\n\\n\".join(labelled)\n",
    "        sections = [\n",
    "            f\"Question:\\n{query}\\n\\n\",\n",
    "            f\"Context (do not use outside this):\\n{context_block}\\n\\n\",\n",
    "            \"Instructions:\\n- Keep answers concise and faithful to the text.\\n\",\n",
    "            \"- Use [file #chunk] inline where relevant.\",\n",
    "        ]\n",
    "        return \"\".join(sections)\n",
    "\n",
    "def _snippet(txt: str, n: int = 220) -> str:\n",
    "    \"\"\"Collapse whitespace runs to single spaces and cap at `n` chars, adding \"…\" if cut.\"\"\"\n",
    "    collapsed = \" \".join(txt.split())\n",
    "    if len(collapsed) <= n:\n",
    "        return collapsed\n",
    "    return collapsed[:n] + \"…\"\n",
    "\n",
    "class RagQAService:\n",
    "    \"\"\"Retrieval + generation: fetch context, ask the LLM, append a references list.\"\"\"\n",
    "    def __init__(self, index: BareActsIndex, llm: MultiLLM):\n",
    "        self.index = index\n",
    "        self.llm = llm\n",
    "        self.builder = PromptBuilder()\n",
    "\n",
    "    def answer(self, *, question: str, model_label: str, k: int = 6) -> str:\n",
    "        \"\"\"Answer `question` with the model behind `model_label`, grounded on the top-`k` chunks.\n",
    "\n",
    "        Returns markdown: the model reply, then a \"**References**\" section with one\n",
    "        \"- [source #chunk_id] snippet\" line per retrieved chunk.\n",
    "        \"\"\"\n",
    "        hits = self.index.query(question, k=k)\n",
    "        prompt = self.builder.build_user(question, hits)\n",
    "        reply = self.llm.complete(\n",
    "            model_label=model_label,\n",
    "            system=self.builder.SYSTEM,\n",
    "            user=prompt,\n",
    "        )\n",
    "\n",
    "        ref_lines = []\n",
    "        for hit in hits:\n",
    "            meta = hit[\"meta\"]\n",
    "            ref_lines.append(f\"- [{meta['source']} #{meta['chunk_id']}] {_snippet(hit['text'])}\")\n",
    "        return f\"{reply}\\n\\n**References**\\n\" + \"\\n\".join(ref_lines)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4862732b",
   "metadata": {},
   "outputs": [],
   "source": [
    "llm = MultiLLM(MODEL_REGISTRY)\n",
    "qa_service = RagQAService(index=index, llm=llm)\n",
    "\n",
    "# Readiness report only: this cell makes no API calls and spends no tokens.\n",
    "if len(AVAILABLE_MODELS) > 0:\n",
    "    print(\"Ready. Default model:\", DEFAULT_MODEL)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0b1512b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def chat_fn(message: str, history: List[Dict], model_label: str, top_k: int) -> str:\n",
    "    \"\"\"Gradio ChatInterface callback that answers `message` through the RAG service.\n",
    "\n",
    "    `history` is passed by Gradio but unused: every question is answered on its own.\n",
    "    Errors are returned as a chat message rather than raised, so they show in the UI.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        reply = qa_service.answer(question=message, model_label=model_label, k=int(top_k))\n",
    "    except Exception as err:\n",
    "        reply = f\"⚠️ {err}\"\n",
    "    return reply\n",
    "\n",
    "DEFAULT_QUESTION = \"Which sections deal with punishment for murder?\"\n",
    "\n",
    "# Close any Gradio servers already running in this kernel (e.g. from an earlier run\n",
    "# of this cell) so re-running replaces the app instead of stacking servers on new ports.\n",
    "gr.close_all()\n",
    "\n",
    "with gr.Blocks(title=\"Legal QnA • Bare Acts (RAG + Multi-LLM)\") as app:\n",
    "    gr.Markdown(\"### 🧑‍⚖️ Legal Q&A on Bare Acts (RAG) — Multi-Provider LLM\")\n",
    "    with gr.Row():\n",
    "        model_dd = gr.Dropdown(\n",
    "            choices=AVAILABLE_MODELS or [\"OpenAI • GPT-4o-mini\"],\n",
    "            value=DEFAULT_MODEL if AVAILABLE_MODELS else None,\n",
    "            label=\"Model\"\n",
    "        )\n",
    "        topk = gr.Slider(2, 12, value=6, step=1, label=\"Top-K context\")\n",
    "\n",
    "    # Dropdown and slider values reach chat_fn as its model_label and top_k arguments.\n",
    "    chat = gr.ChatInterface(\n",
    "        fn=chat_fn,\n",
    "        type=\"messages\",\n",
    "        additional_inputs=[model_dd, topk],\n",
    "        textbox=gr.Textbox(\n",
    "            value=DEFAULT_QUESTION,\n",
    "            label=\"Ask a legal question\",\n",
    "            placeholder=\"Type your question about BNS/IPC/Constitution…\"\n",
    "        ),\n",
    "    )\n",
    "\n",
    "app.launch(inbrowser=True)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llm-engineering",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
